import argparse
import json
import os
import random
import sys
import time
from factory import *

from ray import tune, air
from ray.air import session
from ray.tune import CLIReporter
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers import ASHAScheduler

import torch
import torch.nn.functional as F


def sample_parameter(param):
    """Convert one parameter spec from the JSON config into a Ray Tune search space.

    Args:
        param: A dict describing one hyperparameter. It is either a continuous
            range given by ``"min"`` and ``"max"`` (plus an optional
            ``"distribution"``: ``"loguniform"`` selects a log-uniform range,
            any other value or a missing key selects a uniform range), or a
            discrete set given by ``"values"``.

    Returns:
        ``tune.loguniform`` / ``tune.uniform`` for a range spec,
        ``tune.choice`` for a values spec, or ``None`` when the spec has
        neither ``"min"`` nor ``"values"``.
    """
    if "min" in param:
        # A missing "distribution" key falls back to uniform rather than
        # raising KeyError, matching the default branch below.
        if param.get("distribution") == "loguniform":
            return tune.loguniform(param["min"], param["max"])
        return tune.uniform(param["min"], param["max"])
    if "values" in param:
        return tune.choice(param["values"])
    # NOTE(review): an unrecognised spec silently becomes a constant None
    # hyperparameter; kept as-is for backward compatibility.
    return None


def run_sample(config):
    """Build, train and report one Ray Tune trial.

    Creates the experiment directory, device, reporter, dataset, model,
    optimizer and experiment from the sampled ``config``, then trains. The
    dict returned by ``experiment.train`` is printed and passed to
    ``session.report``.

    Args:
        config: The trial configuration sampled by Ray Tune. Must contain
            ``root_dir``, ``experiment_name``, ``random_seed``,
            ``dataset_name``, ``model_name`` and the sub-dicts ``loader``,
            ``model``, ``optimizer`` and ``experiment``.
    """
    experiment_dir = make_dir(config["root_dir"], config["experiment_name"])

    device = set_device()

    reporter = make_reporter(experiment_dir=experiment_dir, cfg=config)

    # Seed from the trial's own config, not the script-level `cfg` global,
    # which only exists when this file is executed as __main__.
    set_seed(config["random_seed"])

    data, data_loaders = make_dataset(root_dir=config["root_dir"], dataset_name=config["dataset_name"], params=config["loader"])

    model = make_model(model_name=config["model_name"], data=data, device=device, params=config["model"])

    optimizer = make_optimizer(model=model, params=config["optimizer"])

    experiment = make_experiment(model=model, optimizer=optimizer, reporter=reporter, device=device,
                                 experiment_dir=experiment_dir, params=config["experiment"])

    session_results = experiment.train(**data_loaders)

    print(f"session_results: {session_results}")
    # The tuner in __main__ optimises the "MCC" key of this dict
    # (TuneConfig(metric="MCC")), so experiment.train must provide it.
    session.report(session_results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--root_dir', type=str, default="/Users/muberra/Documents/research-repo/temporal-interaction-classification")
    parser.add_argument('-d', '--dataset_name', type=str, default="yelpchi")
    parser.add_argument('-c', "--config_file", type=str, default="profile_builder/tune_trial3.json")
    args = parser.parse_args()

    # The search-space description lives at <root_dir>/configs/<config_file>.
    config_path = os.path.join(args.root_dir, "configs", args.config_file)
    with open(config_path, mode="r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Command-line values override these two keys from the JSON file.
    cfg["root_dir"] = args.root_dir
    cfg["dataset_name"] = args.dataset_name

    root_dir = cfg["root_dir"]
    dataset_name = cfg["dataset_name"]
    model_name = cfg["model_name"]
    experiment_name = cfg["experiment_name"]
    random_seed = cfg["random_seed"]

    # Fixed values first, then one {name: search space} dict per section.
    # Each section in the JSON is a list of specs with a "name" key plus
    # either "min"/"max"(/"distribution") or "values".
    _cfg = {
        "root_dir": root_dir,
        "dataset_name": dataset_name,
        "model_name": model_name,
        "experiment_name": experiment_name,
        "random_seed": random_seed,
    }
    for section in ("loader", "model", "optimizer", "experiment"):
        _cfg[section] = {spec["name"]: sample_parameter(spec) for spec in cfg[section]}

    trainable = tune.with_parameters(run_sample)

    # One GPU per trial when CUDA is available, otherwise two CPU cores.
    resources = {"gpu": 1} if torch.cuda.is_available() else {"cpu": 2}
    trainable_with_resources = tune.with_resources(trainable, resources)

    # A single tag shared by the Tuner run directory and both output files,
    # so all artifacts of one run can be matched up afterwards. Fixed-point
    # formatting guarantees a "." to split on (str() may use exponent form,
    # e.g. "1e-05").
    run_tag = time.strftime("%Y_%m_%d_%H_%M_%S") + "_rid_" + f"{random.random():.16f}".split(".")[1]

    algo = OptunaSearch()
    local_dir = os.path.join(root_dir, "results", experiment_name)
    tuner = tune.Tuner(
        trainable_with_resources,
        tune_config=tune.TuneConfig(
            metric="MCC",
            mode="max",
            search_alg=ConcurrencyLimiter(algo, max_concurrent=8),
            num_samples=250,
        ),
        run_config=air.RunConfig(
            local_dir=local_dir + "/tuners/",
            name=run_tag,
            stop={"training_iteration": 5},
        ),
        param_space=_cfg,
    )
    results = tuner.fit()

    best_result = results.get_best_result()
    file_suffix = f"{dataset_name}_{model_name}_seed_{random_seed}_at_{run_tag}.json"

    # Make sure the output directory exists before writing into it.
    os.makedirs(local_dir, exist_ok=True)
    with open(os.path.join(local_dir, "best_config_" + file_suffix), "w", encoding="utf-8") as outfile:
        json.dump(best_result.config, outfile)
    with open(os.path.join(local_dir, "best_result_" + file_suffix), "w", encoding="utf-8") as outfile:
        json.dump(best_result.metrics, outfile)
